import torch
import torch.nn.functional as F
import random

from torch import Tensor
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset

from .complex import Cochain, Complex, CochainBatch, ComplexBatch
from .Augmentation import Augmentation
from .InverseSampling import InverseSampling
from .Propagation import Propagation
from .PositionAwareEncoder import PositionAwareEncoder
from .SimilarityFunctions import SimilarityFunctions
from .utility import process_tu_dataset
from .TaskDecoder import TaskDecoder
from .extract_ring import extract_ring_mean,ring_contrastive_loss

class ToyGraphBase:
    """Retrieval memory built from a resource graph dataset.

    Every resource graph, and every augmentation of it, contributes one entry
    to each of the parallel ``resource_*`` tensors:

    * ``resource_keys``: L2-normalized ``[graph embedding | ring mean]``.
    * ``resource_values``: ``[pooled k-hop features (or graph embedding) | ring mean]``.
    * ``resource_labels``: one-hot graph label.
    * ``resource_positions``: mean position-aware structural code.
    * ``resource_ring_feats``: ring mean (truncated to the embedding width).

    ``retrieve`` scores queries against all entries with a weighted sum of
    structural, semantic and ring cosine similarities. It returns the top-k
    values and labels together with softmax weights.
    """

    def __init__(self, pretrain_model, num_class, emb_size, query_graph_hop, noise_std=0.1):
        """
        Args:
            pretrain_model: Model providing ``embed(features, adj)``, which returns
                a ``(node_emb, _)`` pair, and ``parameters()``; the latter is used
                to pick the device.
            num_class: Number of graph classes (width of the one-hot labels).
            emb_size: Embedding size; stored but not read inside this class.
            query_graph_hop: Hop count for k-hop aggregation; resource graphs
                use the same value.
            noise_std: Std of the Gaussian noise that ``retrieve(..., add_noise=True)``
                adds to retrieved values. It may also be reassigned after
                construction. The 0.1 default is an assumption — TODO confirm.
        """
        self.pretrain_model = pretrain_model
        self.num_class = num_class
        self.emb_size = emb_size
        self.query_graph_hop = query_graph_hop
        self.toy_graph_hop = query_graph_hop
        self.retrieve_num = 5  # top-k entries returned per query
        self.num_augment_scale = 1
        # Weights of the structural / semantic / ring similarity mix in retrieve().
        self.structure_weight = 0.5
        self.semantic_weight = 0.4
        self.ring_weight = 0.1
        # Parameters forwarded to PositionAwareEncoder.
        self.num_anchors = 16
        self.dis_q = 0.1
        self.noise_std = noise_std

        # The following are populated by build_toy_graph().
        self.resource_keys = None
        self.resource_values = None
        self.resource_labels = None
        self.resource_positions = None
        self.resource_complexes = None
        self.resource_ring_feats = None

    def build_toy_graph(self, resource_dataset: "TUDataset"):
        """(Re)build the memory from ``resource_dataset``.

        The dataset is read one graph at a time, without shuffling. Each graph
        is augmented via ``Augmentation.augment_graph``, and every augmented
        view is appended as one entry. Finally the per-entry tensors are
        concatenated along dim 0.

        Raises:
            ValueError: if the dataset yields no entries at all.
        """
        self.resource_keys = []
        self.resource_values = []
        self.resource_labels = []
        self.resource_positions = []
        self.resource_complexes = []
        self.resource_ring_feats = []

        # Use the model's device rather than a hard-coded .cuda(), so CPU-only
        # runs work and the data lands where the model lives.
        device = next(self.pretrain_model.parameters()).device
        num_node_attributes = resource_dataset.num_node_attributes
        resource_loader = DataLoader(resource_dataset, batch_size=1, shuffle=False)

        for data in resource_loader:
            features, adj, graph_labels, complex_batch, _ = process_tu_dataset(
                data, self.num_class, num_node_attributes
            )
            features = features.to(device)
            adj = adj.to(device)
            label = F.one_hot(graph_labels, num_classes=self.num_class).float()
            for aug_features, aug_adj in Augmentation.augment_graph(self.num_augment_scale, features, adj):
                self._build_toy_graph_base(aug_features, aug_adj, label, complex_batch)

        if not self.resource_keys:
            # torch.cat([]) would fail with an unhelpful message.
            raise ValueError("[ToyGraphBase] resource_dataset produced no toy graph entries")

        self.resource_keys = torch.cat(self.resource_keys, dim=0)
        self.resource_values = torch.cat(self.resource_values, dim=0)
        self.resource_labels = torch.cat(self.resource_labels, dim=0)
        self.resource_positions = torch.cat(self.resource_positions, dim=0)
        self.resource_ring_feats = torch.cat(self.resource_ring_feats, dim=0)

    @torch.no_grad()
    def _build_toy_graph_base(self, features, adj, label, complex_batch):
        """Compute one memory entry for a single (augmented) graph and append it.

        The method runs under ``torch.no_grad()`` because every stored tensor
        is detached; this avoids building an autograd graph per resource graph.
        """
        device = next(self.pretrain_model.parameters()).device

        features = features.to(device).float()
        adj = adj.to(device).float()
        label = label.to(device)
        node_emb, _ = self.pretrain_model.embed(features, adj)
        # Embedding width D (last axis); size(1) would be the node axis for 3-D outputs.
        emb_dim = node_emb.size(-1)

        # === Graph-level embedding, reduced to a single row ===
        graph_embed = node_emb.mean(dim=0, keepdim=True)
        if graph_embed.dim() == 3:
            graph_embed = graph_embed.mean(dim=1)
        if graph_embed.dim() == 2 and graph_embed.size(0) > 1:
            graph_embed = graph_embed.mean(dim=0, keepdim=True)
        elif graph_embed.dim() == 1:
            graph_embed = graph_embed.unsqueeze(0)

        ring_mean = extract_ring_mean(node_emb, complex_batch)
        # Accept both a 1-D ring mean and an already batched [1, D_r] one, then
        # broadcast it to the row count of graph_embed.
        if ring_mean.dim() == 1:
            ring_mean = ring_mean.unsqueeze(0)
        ring_mean = ring_mean.expand(graph_embed.size(0), -1)

        # Key uses the full (untruncated) ring mean.
        key = F.normalize(torch.cat([graph_embed, ring_mean], dim=-1), p=2, dim=-1)

        # === Value: K-hop aggregation ===
        if self.toy_graph_hop > 0:
            # Aggregate the node *features* (not the key), then pool over nodes.
            value = Propagation.aggregate_k_hop_features(adj, features, self.toy_graph_hop)  # [N, F]
            value = value.mean(dim=0, keepdim=True)  # [1, F]
        else:
            value = graph_embed

        # The ring part of the value must have width D.
        if ring_mean.size(1) != emb_dim:
            print(f"[DEBUG] Fix ring_mean shape: {ring_mean.shape} → [:,{emb_dim}]")
            ring_mean = ring_mean[:, :emb_dim]
        assert ring_mean.size(1) == emb_dim, f"[ToyGraphBase] ring_mean final wrong shape: {ring_mean.shape}"

        # Aggregated features are only usable when their width equals D;
        # otherwise fall back to the graph embedding.
        if value.size(1) != emb_dim:
            value = graph_embed.clone()
        assert value.size(1) == emb_dim, f"[ToyGraphBase] value final wrong shape: {value.shape}"

        # Concatenate the ring enhancement: value becomes [1, 2D].
        value = torch.cat([value, ring_mean], dim=-1)
        assert value.size(1) == 2 * emb_dim, f"[ToyGraphBase] value shape after concat wrong: {value.shape}"

        pos = PositionAwareEncoder.encode_position_aware_code(adj, self.num_anchors, self.dis_q)
        pos = pos.mean(dim=0, keepdim=True)  # [1, num_anchors]

        self.resource_keys.append(key.detach().clone())
        self.resource_values.append(value.detach().clone())
        self.resource_labels.append(label)
        self.resource_positions.append(pos)
        self.resource_ring_feats.append(ring_mean.detach().clone())

    def retrieve(self, search_keys: Tensor, search_adj: Tensor, complex_batch, search_ring_feat: Tensor, add_noise: bool):
        """Return the top-k memory entries for each of the B queries.

        Args:
            search_keys: Query keys, B rows, compared with ``resource_keys``.
            search_adj: Either one adjacency ``[N, N]``, whose structural code
                is shared by all B queries, or a batch ``[B, N, N]``.
            complex_batch: Not used by this method; kept for interface compatibility.
            search_ring_feat: Query ring features, compared with
                ``resource_ring_feats`` (presumably B rows — TODO confirm).
            add_noise: If True, retrieve twice as many entries and add Gaussian
                noise (std ``self.noise_std``) to their values.

        Returns:
            A tuple ``(rag_embeddings [B, K, Dv], rag_labels [B, K, num_class],
            rag_weights [B, K])``. K is ``retrieve_num`` (doubled when
            ``add_noise``), capped at the number of stored entries. The weights
            are a softmax over the top-k scores.

        Raises:
            ValueError: if ``search_adj`` is neither 2-D nor 3-D.
        """
        B = search_keys.size(0)

        # === Semantic similarity === [B, num_toy]
        semantic_sim = SimilarityFunctions.calculate_cosine_similarity(search_keys, self.resource_keys)

        # === Structure similarity ===
        if search_adj.dim() == 2:
            pos = PositionAwareEncoder.encode_position_aware_code(search_adj, self.num_anchors, self.dis_q)
            search_positions = pos.mean(dim=0, keepdim=True).repeat(B, 1)
        elif search_adj.dim() == 3:
            search_positions = PositionAwareEncoder.encode_batch_graph_signature(
                search_adj, self.num_anchors, self.dis_q
            )
        else:
            raise ValueError(f"[ERROR] search_adj must be [N,N] or [B,N,N], got {search_adj.shape}")
        structure_sim = SimilarityFunctions.calculate_cosine_similarity(
            search_positions, self.resource_positions
        )  # [B, num_toy]

        # === Ring similarity === [B, num_toy]
        ring_sim = SimilarityFunctions.calculate_cosine_similarity(search_ring_feat, self.resource_ring_feats)

        # === Weighted combination (dtype-agnostic, unlike a float32 weight tensor) ===
        similarity_scores = (
            self.structure_weight * structure_sim
            + self.semantic_weight * semantic_sim
            + self.ring_weight * ring_sim
        )  # [B, num_toy]

        retrieve_num = 2 * self.retrieve_num if add_noise else self.retrieve_num
        # torch.topk raises if k exceeds the number of candidates.
        k = min(retrieve_num, similarity_scores.size(-1))
        topk_scores, topk_indices = torch.topk(similarity_scores, k, dim=-1, largest=True, sorted=True)
        rag_weights = torch.softmax(topk_scores, dim=-1)  # [B, K]

        rag_embeddings = self.resource_values[topk_indices]  # [B, K, Dv]
        rag_labels = self.resource_labels[topk_indices]  # [B, K, num_class]
        if add_noise:
            rag_embeddings = self._add_noise(rag_embeddings)

        return rag_embeddings, rag_labels, rag_weights

    def show(self):
        """Print memory shapes and cosine-similarity stats of 500 random key pairs."""
        print("[ToyGraphBase Summary]")
        if not isinstance(self.resource_keys, torch.Tensor):
            print("toy graph not built yet (call build_toy_graph first)")
            return

        print("resource_keys:", self.resource_keys.shape)
        print("resource_positions:", self.resource_positions.shape)
        print("resource_labels:", self.resource_labels.shape)

        N = self.resource_keys.size(0)
        if N == 0:
            # torch.randint(0, 0, ...) would raise.
            print("Sampled Proto Cosine: no keys to sample")
            return

        sample_idx = torch.randint(0, N, (500, 2), device=self.resource_keys.device)
        vec1 = self.resource_keys[sample_idx[:, 0]]
        vec2 = self.resource_keys[sample_idx[:, 1]]
        cos_sim = F.cosine_similarity(vec1, vec2, dim=-1)

        print("Sampled Proto Cosine mean:", cos_sim.mean().item(),
              "min:", cos_sim.min().item(),
              "max:", cos_sim.max().item())

    def _add_noise(self, embeddings):
        """Return ``embeddings`` plus i.i.d. Gaussian noise with std ``self.noise_std``."""
        noise = torch.randn_like(embeddings) * self.noise_std
        return embeddings + noise

